{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### DRM experiment simulation - extended model\n",
    "\n",
    "The Deese-Roediger-McDermott task is a classic way to measure memory distortion. This notebook tries to recreate the human results in VAE and AE models."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Installation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!pip install tensorflow==2.11.0\n",
    "!pip install tensorflow-datasets\n",
    "!pip install tfds-nightly\n",
    "!pip install scikit-learn --upgrade"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Imports:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer\n",
    "import tensorflow_datasets as tfds\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from config import DRM_lists, lures\n",
    "from data_preparation import *\n",
    "from generative_model import *\n",
    "from hopfield_models import *\n",
    "from itertools import groupby\n",
    "from operator import itemgetter\n",
    "from math import sqrt\n",
    "import random\n",
    "\n",
    "tf.random.set_seed(1234)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### VAE pre-training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "x_train, vectorizer = prepare_data(max_df=0.1, min_df=0.0005, ids=False)\n",
    "\n",
    "# vae = train_vae(x_train, vectorizer, eps=100, ld=300, beta=0.001, batch=128, l1_value=0.01,\n",
    "#                 model_save_path='300ld_100eps_0.001beta_128batch_0.01l1.h5')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model_weights_path = '300ld_100eps_0.001beta_128batch_0.01l1.h5'\n",
    "vae = load_vae(vectorizer, ld=300, beta=0.001, \n",
    "               model_weights_path=model_weights_path)\n",
    "\n",
    "_, id_vectorizer = prepare_data(max_df=0.1, min_df=0.0005, ids=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Extended model - store latent codes plus IDs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create lists with additional unique spatiotemporal features:\n",
    "def create_experiences(num=None):\n",
    "    if num is None:\n",
    "        experiences = [' '.join([f'id_{n}'] + l) for n, l in enumerate(list(DRM_data.values()))]\n",
    "    else:\n",
    "        experiences = [' '.join([f'id_{n}'] + l[0:num]) for n, l in enumerate(list(DRM_data.values()))]\n",
    "    print(\"Example list with unique spatiotemporal feature:\")\n",
    "    print(experiences[0])\n",
    "    return experiences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# set latent dimension\n",
    "ld = 300    \n",
    "\n",
    "def get_latent_code(sent):\n",
    "    encoded = vae.encoder.predict(vectorizer.transform([sent]))[0]\n",
    "    return encoded\n",
    "\n",
    "def encode_patterns(experiences):\n",
    "    # create MHN with dimension ld + id_vectorizer vocabulary size\n",
    "    net = ContinuousHopfield(ld + len(id_vectorizer.vocabulary_.keys()),\n",
    "                             beta=100,\n",
    "                             do_normalization=False)\n",
    "    \n",
    "    patterns = []\n",
    "    for test_text in experiences:\n",
    "        # get latent code for list\n",
    "        latent = get_latent_code(test_text)\n",
    "        # flatten latent to 1D array\n",
    "        latent = latent.flatten()  \n",
    "        # get vector representing unique spatiotemporal context\n",
    "        id_counts = id_vectorizer.transform([test_text.split()[0]]).toarray()\n",
    "        # check unique spatiotemporal context\n",
    "        print(id_vectorizer.inverse_transform(id_counts))\n",
    "        id_counts = id_counts.flatten() \n",
    "        pattern = list(latent) + list(id_counts)\n",
    "        patterns.append(pattern)\n",
    "    \n",
    "    patterns = [np.array(p).reshape(-1, 1) for p in patterns]\n",
    "    net.learn(np.array(patterns))\n",
    "    return net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Test recall:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def hybrid_recall(test_text, net):\n",
    "    latent = np.full((1, ld), 0)\n",
    "    latent = latent.flatten()  # flatten latent to 1D array\n",
    "    id_counts = id_vectorizer.transform([test_text]).toarray()\n",
    "    id_counts = id_counts.flatten()  # flatten id_counts to 1D array\n",
    "    pattern = list(latent) + list(id_counts)  # concatenating two lists\n",
    "\n",
    "    memory = net.retrieve(np.array(pattern).reshape(-1, 1))\n",
    "    \n",
    "    decoded = vae.decoder.predict(memory[0:ld].reshape((1,ld)))\n",
    "    top_words = [(word_lookup[index], decoded[0][index]) for index in np.argsort(-decoded)[0]][0:15]\n",
    "\n",
    "    unpredictable_component = id_vectorizer.inverse_transform(np.array(memory[ld:]).reshape((1, 4226)))\n",
    "    recalled_words = [tuple((unpredictable_component[0][0], 1))] + list(top_words)\n",
    "    return recalled_words\n",
    "\n",
    "def hybrid_plot(ax, terms, scores, clrs, lure_word):\n",
    "    ax.bar(terms, scores, color=clrs, alpha=0.5)\n",
    "    ax.axhline(y=0.5, color='grey', linestyle='--') # Add a dashed line at y=0.5\n",
    "    ax.set_ylabel('Recall score', fontsize=22)\n",
    "    ax.set_title(f\"Lure word '{lure_word}'\", fontsize=22)\n",
    "    plt.sca(ax)\n",
    "    plt.xticks(rotation=90, fontsize=22)\n",
    "\n",
    "word_lookup = {v:k for k,v in vectorizer.vocabulary_.items()}\n",
    "\n",
    "fig, axs = plt.subplots(len(lures), 1, figsize=(8, 4*len(lures)))\n",
    "fig.tight_layout(h_pad=14)\n",
    "\n",
    "experiences = create_experiences(num=None)\n",
    "net = encode_patterns(experiences)\n",
    "\n",
    "for i, ax in enumerate(axs):\n",
    "    lure = lures[i]\n",
    "    list_words = DRM_data[lures[i]] + [f'id_{i}']\n",
    "    recalled = hybrid_recall(f'id_{i}', net)\n",
    "    terms = [i[0] for i in recalled]\n",
    "    scores = [i[1] for i in recalled]\n",
    "    clrs = ['red' if x == lures[i] else 'blue' if x in list_words else 'grey' for x in terms]\n",
    "    hybrid_plot(ax, terms, scores, clrs, lure)\n",
    "\n",
    "plt.savefig('mhn_drm.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Show longer lists increase false recall:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "nums = range(3, 15)\n",
    "\n",
    "def create_experiences(num=None):\n",
    "    if num is None:\n",
    "        experiences = [' '.join([f'id_{n}'] + l) for n, l in enumerate(list(DRM_data.values()))]\n",
    "    else:\n",
    "        experiences = [' '.join([f'id_{n}'] + random.sample(l, num)) for n, l in enumerate(list(DRM_data.values())) if len(l) >= num]\n",
    "    return experiences\n",
    "\n",
    "def encode_num(num=3):\n",
    "    experiences = create_experiences(num=num)\n",
    "    net = encode_patterns(experiences)\n",
    "    return net, experiences\n",
    "\n",
    "def get_recalled_for_num(lure, net):\n",
    "    recalled = hybrid_recall(f'id_{lures.index(lure)}', net)\n",
    "    terms = [i[0] for i in recalled if i[1] > 0.5]\n",
    "    return terms\n",
    "\n",
    "\n",
    "results = []\n",
    "\n",
    "for i in range(20):\n",
    "    for num in nums:\n",
    "        print(num)\n",
    "        net, experiences = encode_num(num=num)\n",
    "        for ind, lure in enumerate(list(DRM_data.keys())):\n",
    "            print(lure)\n",
    "            if len(DRM_data[lure]) >= num:\n",
    "                recalled = get_recalled_for_num(lure, net)\n",
    "                if lure in recalled:\n",
    "                    lure_recalled_bool = 1\n",
    "                else:\n",
    "                    lure_recalled_bool = 0\n",
    "                results.append({'lure': lure,\n",
    "                               'num': num,\n",
    "                               'lure_recalled': lure_recalled_bool})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "data = results\n",
    "# Sort the data by 'num' first\n",
    "data.sort(key=itemgetter('num'))\n",
    "\n",
    "nums = []\n",
    "mean_values = []\n",
    "errors = []\n",
    "\n",
    "# Use groupby to group the data by 'num'\n",
    "for num, group in groupby(data, key=itemgetter('num')):\n",
    "    # Convert group to a list first\n",
    "    group = list(group)\n",
    "    \n",
    "    # Calculate the mean 'lure_recalled' for each group\n",
    "    lure_recalled_values = [item['lure_recalled'] for item in group]\n",
    "    print(len(lure_recalled_values))\n",
    "    lure_recalled_mean = sum(lure_recalled_values) / len(group)\n",
    "    \n",
    "    # Calculate the SEM for each group\n",
    "    sem = np.std(lure_recalled_values) / sqrt(len(group))\n",
    "    \n",
    "    nums.append(num)\n",
    "    mean_values.append(lure_recalled_mean)\n",
    "    errors.append(sem)\n",
    "\n",
    "# Create error bars plot\n",
    "plt.errorbar(nums, mean_values, yerr=errors, fmt='-o', capsize=5, capthick=2, label='Model data')\n",
    "\n",
    "rr_data = {3: 0.03, 6: 0.11, 9: 0.19, 12: 0.27, 15: 0.31}\n",
    "rr_nums = list(rr_data.keys())\n",
    "rr_values = list(rr_data.values())\n",
    "\n",
    "plt.plot(rr_nums, rr_values, '-o', label='Human data', color='grey')\n",
    "\n",
    "plt.xlabel('Number of associates studied', fontsize=18)\n",
    "plt.ylabel('Probability of false recall', fontsize=18)\n",
    "plt.xticks(fontsize=16)\n",
    "plt.yticks(fontsize=16)\n",
    "plt.legend(fontsize=16)  # This adds a legend to the plot\n",
    "plt.savefig('lure_words_fraction.pdf', bbox_inches='tight', dpi=300)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Statistical analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import correlation\n",
    "\n",
    "print(correlation.corr(nums, mean_values, method='spearman_rho', ci=0.95))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
